{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "PyTorch on TPUs: Fast Neural Style Transfer",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dbOTXWKBnBiP",
        "colab_type": "text"
      },
      "source": [
        "![alt text](https://i.imgur.com/ipYa6Q8.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h1mYGqkc1kqv",
        "colab_type": "text"
      },
      "source": [
        "## PyTorch on TPUs: Fast Neural Style Transfer\n",
        "\n",
        "This notebook lets you run a pre-trained fast neural style transfer network implemented in PyTorch on a Cloud TPU. You can combine pictures and styles to create fun new images. \n",
        "\n",
        "You can learn more about fast neural style transfer from its implementation [here](https://github.com/pytorch/examples/tree/master/fast_neural_style) or the original paper, available [here](https://arxiv.org/abs/1603.08155).\n",
        "\n",
        "This notebook loads PyTorch and stores the network on your Google drive. After this automated setup process (it takes a couple minutes) you can put in a link to an image and see your style applied in seconds!\n",
        "\n",
        "You can find more examples of running PyTorch on TPUs [here](https://github.com/pytorch/xla/tree/master/contrib/colab), including tutorials on how to run PyTorch on hundreds of TPUs with Google Cloud Platform. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YofXQrnxmf5r",
        "colab_type": "text"
      },
      "source": [
        "### Installs PyTorch & Loads the Networks\n",
        "(This may take a couple minutes.)\n",
        "\n",
        "Fast neural style transfer networks use the same architecture but different weights to encode their styles. This notebook creates four fast neural style transfer networks: \"rain princess,\" \"candy,\" \"mosaic,\" and \"udnie.\" You can apply these styles below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sg7i8Wk6Iblu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C7fLX2EdIjL_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-df27v_-Iqrb",
        "colab_type": "text"
      },
      "source": [
        "### Only run the below commented cell if you would like a nightly release"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6vDR6iDcfULe",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# VERSION = \"20200325\"  #@param [\"1.5\" , \"20200325\", \"nightly\"]\n",
        "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "# !python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sPJVqAKyml5W",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from google.colab.patches import cv2_imshow\n",
        "import cv2\n",
        "import sys\n",
        "\n",
        "# Configures repo in local colab fs\n",
        "REPO_DIR = '/demo'\n",
        "%mkdir -p \"$REPO_DIR\"\n",
        "%cd \"$REPO_DIR\" \n",
        "%rm -rf examples\n",
        "!git clone https://github.com/pytorch/examples.git \n",
        "%cd \"$REPO_DIR/examples/fast_neural_style\"\n",
        "\n",
        "# Download pretrained weights for styles\n",
        "!python download_saved_models.py\n",
        "%cd \"$REPO_DIR/examples/fast_neural_style/neural_style\"\n",
        "\n",
        "\n",
        "## Creates pre-trained style networks\n",
        "import argparse\n",
        "import os\n",
        "import sys\n",
        "import time\n",
        "import re\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "from torch.optim import Adam\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets\n",
        "from torchvision import transforms\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "import torch_xla.utils.utils as xu\n",
        "import utils\n",
        "from transformer_net import TransformerNet\n",
        "from vgg import Vgg16\n",
        "\n",
        "# Acquires the XLA device (a TPU core)\n",
        "device = xm.xla_device()\n",
        "\n",
        "# Loads pre-trained weights\n",
        "rain_princess_path = '../saved_models/rain_princess.pth'\n",
        "candy_path = '../saved_models/candy.pth'\n",
        "mosaic_path = '../saved_models/mosaic.pth'\n",
        "udnie_path = '../saved_models/udnie.pth'\n",
        "\n",
        "# Loads the pre-trained weights into the fast neural style transfer\n",
        "# network architecture and puts the network on the Cloud TPU core.\n",
        "def load_style(path):\n",
        "  with torch.no_grad():\n",
        "    model = TransformerNet()\n",
        "    state_dict = torch.load(path)\n",
        "    # filters deprecated running_* keys from the checkpoint\n",
        "    for k in list(state_dict.keys()):\n",
        "        if re.search(r'in\\d+\\.running_(mean|var)$', k):\n",
        "            del state_dict[k]\n",
        "    model.load_state_dict(state_dict)\n",
        "    return model.to(device)\n",
        "\n",
        "# Creates each fast neural style transfer network\n",
        "rain_princess = load_style(rain_princess_path)\n",
        "candy = load_style(candy_path)\n",
        "mosaic = load_style(mosaic_path)\n",
        "udnie = load_style(udnie_path)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j1w1G4AcWw9f",
        "colab_type": "text"
      },
      "source": [
        "## Try it out!\n",
        "\n",
        "The next cell loads and display an image from a URL. This image is styled by the following cell. You can re-run these two cells as often as you like to style multiple images.\n",
        "\n",
        "Start by copying and pasting an image URL here (or use the default corgi)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EozMXwIV9iOJ",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "#@markdown ### Image URL (right click -> copy image address):\n",
        "content_image_url = 'https://cdn.pixabay.com/photo/2019/06/11/15/42/corgi-face-4267312__480.jpg' #@param {type:\"string\"}\n",
        "content_image = 'content.jpg'\n",
        "!wget -O \"$content_image\" \"$content_image_url\"\n",
        "RESULT_IMAGE = '/tmp/result.jpg'\n",
        "!rm -f \"$RESULT_IMAGE\"\n",
        "img = cv2.imread(content_image, cv2.IMREAD_UNCHANGED)\n",
        "\n",
        "content_image = utils.load_image(content_image, scale=None)\n",
        "content_transform = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Lambda(lambda x: x.mul(255))\n",
        "    ])\n",
        "content_image = content_transform(content_image)\n",
        "content_image = content_image.unsqueeze(0).to(device)\n",
        "\n",
        "cv2_imshow(img)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e0vHw-aHoG-s",
        "colab_type": "text"
      },
      "source": [
        "To style your image simply uncomment the style you wish to apply below and run the cell!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z0j9i4EWctbU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "with torch.no_grad():\n",
        "  output = rain_princess(content_image)\n",
        "  # output = candy(content_image)\n",
        "  # output = mosaic(content_image)\n",
        "  # output = udnie(content_image)\n",
        "\n",
        "\n",
        "utils.save_image(RESULT_IMAGE, output[0].cpu())\n",
        "img = cv2.imread(RESULT_IMAGE, cv2.IMREAD_UNCHANGED)\n",
        "cv2_imshow(img)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
